#include "framework/unittest.hh"

#include "../core/plato_variable.hh"
#include "../test/struct.hh"

FIXTURE_BEGIN(TestPlatoVariable)

using namespace plato;

// Exercises the operators of an Int variable: compound assignment,
// postfix/prefix increment and decrement, and binary arithmetic with both
// Int and plain integer operands. The value starts at 2 and every step
// below restores it to 2 unless noted, so the statement order matters.
CASE(TestNumeric1) {
  auto domain = new_domain(0);
  auto int_ptr = domain->New<Int>();
  *int_ptr = 2;
  // Type tag and parent reported by a variable created directly on the domain.
  ASSERT_TRUE(int_ptr->type() == PlatoType::INT32);
  ASSERT_TRUE(int_ptr->parent() == 0);
  ASSERT_TRUE(*int_ptr == 2);
  // Compound assignment operators.
  *int_ptr %= 3;
  ASSERT_TRUE(*int_ptr == 2);
  *int_ptr += 1;
  ASSERT_TRUE(*int_ptr == 3);
  *int_ptr -= 1;
  ASSERT_TRUE(*int_ptr == 2);
  // Postfix increment / decrement.
  (*int_ptr)++;
  ASSERT_TRUE(*int_ptr == 3);
  (*int_ptr)--;
  ASSERT_TRUE(*int_ptr == 2);
  // Prefix increment / decrement.
  ++*int_ptr;
  ASSERT_TRUE(*int_ptr == 3);
  --*int_ptr;
  ASSERT_TRUE(*int_ptr == 2);

  *int_ptr *= 2;
  ASSERT_TRUE(*int_ptr == 4);
  *int_ptr /= 2;
  ASSERT_TRUE(*int_ptr == 2);

  // Int (op) Int yields a result that is dereferenced before comparing,
  // whereas Int (op) literal and literal (op) Int are compared directly.
  ASSERT_TRUE(*(*int_ptr + *int_ptr) == 4);
  ASSERT_TRUE(*int_ptr + 2 == 4);
  ASSERT_TRUE(2 + *int_ptr == 4);

  ASSERT_TRUE(*(*int_ptr - *int_ptr) == 0);
  ASSERT_TRUE(*int_ptr - 2 == 0);
  ASSERT_TRUE(2 - *int_ptr == 0);

  ASSERT_TRUE(*int_ptr * 2 == 4);
  ASSERT_TRUE(*int_ptr / 2 == 1);
  ASSERT_TRUE(*int_ptr % 3 == 2);

  // Plain integer on the left-hand side.
  ASSERT_TRUE(2 * *int_ptr == 4);
  ASSERT_TRUE(2 / *int_ptr == 1);

  ASSERT_TRUE(*(*int_ptr % *int_ptr) == 0);

  // Division and modulo by zero.
  // NOTE(review): ASSERT_EXCEPT presumably passes only if the expression throws - confirm in framework/unittest.hh.
  ASSERT_EXCEPT(*int_ptr / 0);
  ASSERT_EXCEPT(*int_ptr % 0);
}

// Int: `%=` with a divisor larger than the stored value keeps the value.
CASE(TestNumeric2) {
  auto dom = new_domain(0);
  auto number = dom->New<Int>();
  *number = 2;
  *number %= 3;  // 2 % 3
  ASSERT_TRUE(*number == 2);
}

// Float counterpart of the Int operator test: compound assignment,
// postfix/prefix increment and decrement, and binary arithmetic with Float
// and plain float operands. The value starts at 2.0f and each step restores
// it, so the statement order matters. All operands are small whole numbers,
// so the comparisons use exact equality.
CASE(TestNumeric3) {
  auto domain = new_domain(0);
  auto float_ptr = domain->New<Float>();
  *float_ptr = 2.0f;
  // Compound assignment operators (including `%=` on a Float).
  *float_ptr %= 3.0f;
  ASSERT_TRUE(*float_ptr == 2.0f);
  *float_ptr += 1.0f;
  ASSERT_TRUE(*float_ptr == 3.0f);
  *float_ptr -= 1.0f;
  ASSERT_TRUE(*float_ptr == 2.0f);
  // Postfix increment / decrement.
  (*float_ptr)++;
  ASSERT_TRUE(*float_ptr == 3.0f);
  (*float_ptr)--;
  ASSERT_TRUE(*float_ptr == 2.0f);
  // Prefix increment / decrement.
  ++(*float_ptr);
  ASSERT_TRUE((*float_ptr) == 3.0f);
  --(*float_ptr);
  ASSERT_TRUE((*float_ptr) == 2.0f);

  (*float_ptr) *= 2.0f;
  ASSERT_TRUE((*float_ptr) == 4.0f);
  (*float_ptr) /= 2.0f;
  ASSERT_TRUE((*float_ptr) == 2.0f);

  // Float (op) Float yields a result that is dereferenced before comparing,
  // whereas Float (op) literal is compared directly.
  ASSERT_TRUE(*(*float_ptr + *float_ptr) == 4.0f);
  ASSERT_TRUE(*float_ptr + 2.0f == 4.0f);

  ASSERT_TRUE(*(*float_ptr - *float_ptr) == 0.0f);
  ASSERT_TRUE(*float_ptr - 2.0f == 0.0f);

  ASSERT_TRUE(*float_ptr * 2.0f == 4.0f);
  ASSERT_TRUE(*float_ptr / 2.0f == 1.0f);

  // Plain float on the left-hand side.
  ASSERT_TRUE(2.0f * *float_ptr == 4.0f);
  ASSERT_TRUE(2.0f / *float_ptr == 1.0f);

  ASSERT_TRUE(*(*float_ptr % *float_ptr) == 0.0f);

  ASSERT_TRUE(*float_ptr % 3.0f == 2.0f);

  // Division and modulo by zero.
  // NOTE(review): ASSERT_EXCEPT presumably passes only if the expression throws - confirm in framework/unittest.hh.
  ASSERT_EXCEPT(*float_ptr / .0f);
  ASSERT_EXCEPT(*float_ptr % .0f);
}

// Int constructed with an explicit sync type and a value argument of 2:
// the variable reads back as 2 right away and survives `%= 3` unchanged.
CASE(TestNumeric4) {
  auto dom = new_domain(0);
  auto number = dom->New<Int>(PlatoVariableSyncType::NONE, 2);
  ASSERT_TRUE(*number == 2);
  *number %= 3;  // 2 % 3
  ASSERT_TRUE(*number == 2);
}

// Int::copy_default(): after overwriting a variable created with 1, calling
// copy_default() brings the value back to 1.
CASE(TestNumeric5) {
  auto dom = new_domain(0);
  auto number = dom->New<Int>(PlatoVariableSyncType::NONE, 1);
  *number = 2;
  ASSERT_TRUE(*number == 2);
  number->copy_default();
  ASSERT_TRUE(*number == 1);
}

// Array<Int>: an element obtained from add() can be assigned and then read
// back through operator[].
CASE(TestArray1) {
  auto dom = new_domain(0);
  auto items = dom->New<Array<Int>>();
  *items->add() = 1;
  ASSERT_TRUE(*(*items)[0] == 1);
}

// Array<Int> defaults: two elements are added with add_default() before
// complete_prototype(), a third with add() afterwards; copy_default() then
// shrinks the array back to the two default elements.
CASE(TestArray2) {
  auto dom = new_domain(0);
  auto items = dom->New<Array<Int>>();
  // Default elements.
  *items->add_default() = 1;
  *items->add_default() = 2;
  items->complete_prototype();
  // Regular element appended after the prototype is complete.
  *items->add() = 3;
  ASSERT_TRUE(*(*items)[2] == 3);
  ASSERT_TRUE(items->size() == 3);
  items->copy_default();
  ASSERT_TRUE(items->size() == 2);
}

// Map<int, Int>: a value added under an int key is retrievable with get().
CASE(TestMap1) {
  auto dom = new_domain(0);
  auto table = dom->New<Map<int, Int>>();
  *table->add(1) = 1;
  ASSERT_TRUE(*table->get(1) == 1);
}

// Map<std::string, Int>: a value added under a string key is retrievable
// with get().
CASE(TestMap2) {
  auto dom = new_domain(0);
  auto table = dom->New<Map<std::string, Int>>();
  *table->add("1") = 1;
  ASSERT_TRUE(*table->get("1") == 1);
}

// Map<int, Int> defaults: two entries are registered with add_default() /
// set_default() before complete_prototype(), a third with add() afterwards;
// copy_default() then leaves only the two default entries with their
// default values.
CASE(TestMap3) {
  auto dom = new_domain(0);
  auto table = dom->New<Map<int, Int>>();
  // Default entries.
  table->add_default(1)->set_default(1);
  table->add_default(2)->set_default(2);
  table->complete_prototype();
  // Regular entry added after the prototype is complete.
  *table->add(3) = 3;
  ASSERT_TRUE(*table->get(3) == 3);
  ASSERT_TRUE(table->size() == 3);
  table->copy_default();
  ASSERT_TRUE(table->size() == 2);
  ASSERT_TRUE(*table->get(2) == 2);
}

// Set<int>: an added key is reported by has().
CASE(TestSet1) {
  auto dom = new_domain(0);
  auto keys = dom->New<Set<int>>();
  keys->add(1);
  ASSERT_TRUE(keys->has(1));
}

// Set<std::string>: an added string key is reported by has().
CASE(TestSet2) {
  auto dom = new_domain(0);
  auto keys = dom->New<Set<std::string>>();
  keys->add("123");
  ASSERT_TRUE(keys->has("123"));
}

// Set<int> defaults: two keys are registered with add_default() before
// complete_prototype(), a third with add() afterwards; copy_default() then
// drops the third key and leaves exactly the two default keys.
CASE(TestSet3) {
  auto dom = new_domain(0);
  auto keys = dom->New<Set<int>>();
  // Default keys.
  keys->add_default(1);
  keys->add_default(2);
  keys->complete_prototype();
  // Regular key added after the prototype is complete.
  keys->add(3);
  ASSERT_TRUE(keys->has(3));
  keys->copy_default();
  ASSERT_TRUE(!keys->has(3));
  ASSERT_TRUE(keys->size() == 2);
}

// TestStruct: writes every member (a..f) and reads each one back.
// Judging by the calls used here, a/b/c are scalar and string members,
// d is array-like, e is set-like and f is map-like; the actual declarations
// live in ../test/struct.hh.
CASE(TestStruct1) {
  auto dom = new_domain(0);
  auto record = dom->New<TestStruct>();
  // Populate each member.
  *record->a = 1;
  *record->b = 1.0f;
  *record->c = "123";
  *record->d->add() = 4;
  record->e->add(1);
  *record->f->add(1) = 5;

  // Read each member back.
  ASSERT_TRUE(*record->a == 1);
  ASSERT_TRUE(*record->b == 1.0f);
  ASSERT_TRUE(*record->c == "123");
  ASSERT_TRUE(*(*record->d)[0] == 4);
  ASSERT_TRUE(record->e->has(1));
  ASSERT_TRUE(*record->f->get(1) == 5);
}

// TestStruct defaults: every member (a..f) receives a default through
// set_default() / add_default(); after complete_prototype() on the struct
// the members read back as those defaults.
CASE(TestStruct2) {
  auto dom = new_domain(0);
  auto record = dom->New<TestStruct>();
  // Register a default for each member.
  record->a->set_default(1);
  record->b->set_default(1.0f);
  record->c->set_default("123");
  record->d->add_default()->set_default(4);
  record->e->add_default(1);
  record->f->add_default(1)->set_default(5);

  record->complete_prototype();

  // Each member now holds its default.
  ASSERT_TRUE(*record->a == 1);
  ASSERT_TRUE(*record->b == 1.0f);
  ASSERT_TRUE(*record->c == "123");
  ASSERT_TRUE(*(*record->d)[0] == 4);
  ASSERT_TRUE(record->e->has(1));
  ASSERT_TRUE(*record->f->get(1) == 5);
}

// Int sync: a value written on a SENDER variable shows up on a RECEIVER
// variable in another domain after the sender's stream is fed into the
// receiver's stream and the receiver runs do_sync().
CASE(TestNumericSerial1) {
  auto tx_domain = new_domain(0);
  auto tx_int = tx_domain->New<Int>(PlatoVariableSyncType::SENDER);
  *tx_int = 1;

  auto rx_domain = new_domain(0);
  auto rx_int = rx_domain->New<Int>(PlatoVariableSyncType::RECEIVER);

  // Sender stream -> receiver stream, then apply on the receiving side.
  tx_domain->get_stream() >> rx_domain->get_stream();
  rx_domain->do_sync();

  ASSERT_TRUE(*rx_int == 1);
}

// Float sync: a value written on a SENDER variable shows up on a RECEIVER
// variable in another domain after the streams are connected and the
// receiver runs do_sync().
CASE(TestNumericSerial2) {
  auto tx_domain = new_domain(0);
  auto tx_float = tx_domain->New<Float>(PlatoVariableSyncType::SENDER);
  *tx_float = 1.0f;

  auto rx_domain = new_domain(0);
  auto rx_float = rx_domain->New<Float>(PlatoVariableSyncType::RECEIVER);

  // Sender stream -> receiver stream, then apply on the receiving side.
  tx_domain->get_stream() >> rx_domain->get_stream();
  rx_domain->do_sync();

  ASSERT_TRUE(*rx_float == 1.0f);
}

// String sync: text written on a SENDER variable shows up on a RECEIVER
// variable in another domain after the streams are connected and the
// receiver runs do_sync().
CASE(TestStringSerial1) {
  auto tx_domain = new_domain(0);
  auto tx_str = tx_domain->New<String>(PlatoVariableSyncType::SENDER);
  *tx_str = "123";

  auto rx_domain = new_domain(0);
  auto rx_str = rx_domain->New<String>(PlatoVariableSyncType::RECEIVER);

  // Sender stream -> receiver stream, then apply on the receiving side.
  tx_domain->get_stream() >> rx_domain->get_stream();
  rx_domain->do_sync();

  ASSERT_TRUE(*rx_str == "123");
}

// Map<int, Int> sync: an entry added on the SENDER map is present on the
// RECEIVER map after the streams are connected and the receiver syncs.
CASE(TestMapSerial1) {
  auto tx_domain = new_domain(0);
  auto tx_map = tx_domain->New<Map<int, Int>>(PlatoVariableSyncType::SENDER);
  *tx_map->add(1) = 2;

  auto rx_domain = new_domain(0);
  auto rx_map = rx_domain->New<Map<int, Int>>(PlatoVariableSyncType::RECEIVER);

  // Sender stream -> receiver stream, then apply on the receiving side.
  tx_domain->get_stream() >> rx_domain->get_stream();
  rx_domain->do_sync();

  ASSERT_TRUE(*rx_map->get(1) == 2);
}

// Map<int, Int> sync with removal: the added entry reaches the receiver on
// the first sync round; after the sender removes it, a second round makes
// the receiver's get() return a value that tests false.
CASE(TestMapSerial2) {
  auto tx_domain = new_domain(0);
  auto tx_map = tx_domain->New<Map<int, Int>>(PlatoVariableSyncType::SENDER);
  *tx_map->add(1) = 2;

  auto rx_domain = new_domain(0);
  auto rx_map = rx_domain->New<Map<int, Int>>(PlatoVariableSyncType::RECEIVER);

  // Round 1: the insertion.
  tx_domain->get_stream() >> rx_domain->get_stream();
  rx_domain->do_sync();

  ASSERT_TRUE(*rx_map->get(1) == 2);

  // Round 2: the removal.
  tx_map->remove(1);
  tx_domain->get_stream() >> rx_domain->get_stream();
  rx_domain->do_sync();

  ASSERT_TRUE(!rx_map->get(1));
}

// Map<int, Int> sync via serialize_all(): the sender's stream is cleared
// first, so the data the receiver gets comes from serialize_all() rather
// than from whatever the earlier add() left in the stream.
CASE(TestMapSerial3) {
  auto tx_domain = new_domain(0);
  auto tx_map = tx_domain->New<Map<int, Int>>(PlatoVariableSyncType::SENDER);
  *tx_map->add(1) = 2;

  auto rx_domain = new_domain(0);
  auto rx_map = rx_domain->New<Map<int, Int>>(PlatoVariableSyncType::RECEIVER);

  // Drop the current stream content, then serialize the whole domain.
  tx_domain->get_stream().clear();
  tx_domain->serialize_all();

  // Sender stream -> receiver stream, then apply on the receiving side.
  tx_domain->get_stream() >> rx_domain->get_stream();
  rx_domain->do_sync();

  ASSERT_TRUE(*rx_map->get(1) == 2);
}

// Map<std::string, Int> sync: an entry added under a string key on the
// SENDER map is present on the RECEIVER map after a sync round.
CASE(TestMapSerial4) {
  auto tx_domain = new_domain(0);
  auto tx_map = tx_domain->New<Map<std::string, Int>>(PlatoVariableSyncType::SENDER);
  *tx_map->add("1") = 2;

  auto rx_domain = new_domain(0);
  auto rx_map = rx_domain->New<Map<std::string, Int>>(PlatoVariableSyncType::RECEIVER);

  // Sender stream -> receiver stream, then apply on the receiving side.
  tx_domain->get_stream() >> rx_domain->get_stream();
  rx_domain->do_sync();

  ASSERT_TRUE(*rx_map->get("1") == 2);
}

// Array<Int> sync: an element added on the SENDER array is readable at
// index 0 on the RECEIVER array after a sync round.
CASE(TestArraySerial1) {
  auto tx_domain = new_domain(0);
  auto tx_arr = tx_domain->New<Array<Int>>(PlatoVariableSyncType::SENDER);
  *tx_arr->add() = 2;

  auto rx_domain = new_domain(0);
  auto rx_arr = rx_domain->New<Array<Int>>(PlatoVariableSyncType::RECEIVER);

  // Sender stream -> receiver stream, then apply on the receiving side.
  tx_domain->get_stream() >> rx_domain->get_stream();
  rx_domain->do_sync();

  ASSERT_TRUE(*(*rx_arr)[0] == 2);
}

// Array<Int> sync with removal: the added element reaches the receiver on
// the first sync round; after the sender removes index 0, a second round
// leaves the receiver's array empty.
CASE(TestArraySerial2) {
  auto tx_domain = new_domain(0);
  auto tx_arr = tx_domain->New<Array<Int>>(PlatoVariableSyncType::SENDER);
  *tx_arr->add() = 2;

  auto rx_domain = new_domain(0);
  auto rx_arr = rx_domain->New<Array<Int>>(PlatoVariableSyncType::RECEIVER);

  // Round 1: the insertion.
  tx_domain->get_stream() >> rx_domain->get_stream();
  rx_domain->do_sync();

  ASSERT_TRUE(*(*rx_arr)[0] == 2);

  // Round 2: the removal.
  tx_arr->remove(0);
  tx_domain->get_stream() >> rx_domain->get_stream();
  rx_domain->do_sync();

  ASSERT_TRUE(rx_arr->size() == 0);
}

// Array<Int> sync via serialize_all(): the sender's stream is cleared
// first, so the data the receiver gets comes from serialize_all() rather
// than from whatever the earlier add() left in the stream.
CASE(TestArraySerial3) {
  auto tx_domain = new_domain(0);
  auto tx_arr = tx_domain->New<Array<Int>>(PlatoVariableSyncType::SENDER);
  *tx_arr->add() = 2;

  auto rx_domain = new_domain(0);
  auto rx_arr = rx_domain->New<Array<Int>>(PlatoVariableSyncType::RECEIVER);

  // Drop the current stream content, then serialize the whole domain.
  tx_domain->get_stream().clear();
  tx_domain->serialize_all();

  // Sender stream -> receiver stream, then apply on the receiving side.
  tx_domain->get_stream() >> rx_domain->get_stream();
  rx_domain->do_sync();

  ASSERT_TRUE(*(*rx_arr)[0] == 2);
}

// Array<String> sync: a string element added on the SENDER array is
// readable at index 0 on the RECEIVER array after a sync round.
CASE(TestArraySerial4) {
  auto tx_domain = new_domain(0);
  auto tx_arr = tx_domain->New<Array<String>>(PlatoVariableSyncType::SENDER);
  *tx_arr->add() = "2";

  auto rx_domain = new_domain(0);
  auto rx_arr = rx_domain->New<Array<String>>(PlatoVariableSyncType::RECEIVER);

  // Sender stream -> receiver stream, then apply on the receiving side.
  tx_domain->get_stream() >> rx_domain->get_stream();
  rx_domain->do_sync();

  ASSERT_TRUE(*(*rx_arr)[0] == "2");
}

// Set<int> sync: a key added on the SENDER set is reported by has() on the
// RECEIVER set after a sync round.
CASE(TestSetSerial1) {
  auto tx_domain = new_domain(0);
  auto tx_set = tx_domain->New<Set<int>>(PlatoVariableSyncType::SENDER);
  tx_set->add(1);

  auto rx_domain = new_domain(0);
  auto rx_set = rx_domain->New<Set<int>>(PlatoVariableSyncType::RECEIVER);

  // Sender stream -> receiver stream, then apply on the receiving side.
  tx_domain->get_stream() >> rx_domain->get_stream();
  rx_domain->do_sync();

  ASSERT_TRUE(rx_set->has(1));
}

// Set<int> sync with removal: the added key reaches the receiver on the
// first sync round; after the sender removes it, a second round removes it
// from the receiver as well.
CASE(TestSetSerial2) {
  auto tx_domain = new_domain(0);
  auto tx_set = tx_domain->New<Set<int>>(PlatoVariableSyncType::SENDER);
  tx_set->add(1);

  auto rx_domain = new_domain(0);
  auto rx_set = rx_domain->New<Set<int>>(PlatoVariableSyncType::RECEIVER);

  // Round 1: the insertion.
  tx_domain->get_stream() >> rx_domain->get_stream();
  rx_domain->do_sync();

  ASSERT_TRUE(rx_set->has(1));

  // Round 2: the removal.
  tx_set->remove(1);
  tx_domain->get_stream() >> rx_domain->get_stream();
  rx_domain->do_sync();

  ASSERT_TRUE(!rx_set->has(1));
}

// Set<int> sync via serialize_all(): the sender's stream is cleared first,
// so the data the receiver gets comes from serialize_all() rather than
// from whatever the earlier add() left in the stream.
CASE(TestSetSerial3) {
  auto tx_domain = new_domain(0);
  auto tx_set = tx_domain->New<Set<int>>(PlatoVariableSyncType::SENDER);
  tx_set->add(1);

  auto rx_domain = new_domain(0);
  auto rx_set = rx_domain->New<Set<int>>(PlatoVariableSyncType::RECEIVER);

  // Drop the current stream content, then serialize the whole domain.
  tx_domain->get_stream().clear();
  tx_domain->serialize_all();

  // Sender stream -> receiver stream, then apply on the receiving side.
  tx_domain->get_stream() >> rx_domain->get_stream();
  rx_domain->do_sync();

  ASSERT_TRUE(rx_set->has(1));
}

// Set<std::string> sync: a string key added on the SENDER set is reported
// by has() on the RECEIVER set after a sync round.
CASE(TestSetSerial4) {
  auto tx_domain = new_domain(0);
  auto tx_set = tx_domain->New<Set<std::string>>(PlatoVariableSyncType::SENDER);
  tx_set->add("1");

  auto rx_domain = new_domain(0);
  auto rx_set = rx_domain->New<Set<std::string>>(PlatoVariableSyncType::RECEIVER);

  // Sender stream -> receiver stream, then apply on the receiving side.
  tx_domain->get_stream() >> rx_domain->get_stream();
  rx_domain->do_sync();

  ASSERT_TRUE(rx_set->has("1"));
}

// TestStruct sync: member `a` written on the SENDER struct reads back with
// the same value on the RECEIVER struct after a sync round.
CASE(TestStructSerial1) {
  auto tx_domain = new_domain(0);
  auto tx_record = tx_domain->New<TestStruct>(PlatoVariableSyncType::SENDER);
  *tx_record->a = 1;

  auto rx_domain = new_domain(0);
  auto rx_record = rx_domain->New<TestStruct>(PlatoVariableSyncType::RECEIVER);

  // Sender stream -> receiver stream, then apply on the receiving side.
  tx_domain->get_stream() >> rx_domain->get_stream();
  rx_domain->do_sync();

  ASSERT_TRUE(*rx_record->a == 1);
}

// TestStruct sync via serialize_all(): the sender's stream is cleared
// first, so the data the receiver gets comes from serialize_all() rather
// than from whatever the earlier assignment left in the stream.
CASE(TestStructSerial2) {
  auto tx_domain = new_domain(0);
  auto tx_record = tx_domain->New<TestStruct>(PlatoVariableSyncType::SENDER);
  *tx_record->a = 1;

  auto rx_domain = new_domain(0);
  auto rx_record = rx_domain->New<TestStruct>(PlatoVariableSyncType::RECEIVER);

  // Drop the current stream content, then serialize the whole domain.
  tx_domain->get_stream().clear();
  tx_domain->serialize_all();

  // Sender stream -> receiver stream, then apply on the receiving side.
  tx_domain->get_stream() >> rx_domain->get_stream();
  rx_domain->do_sync();

  ASSERT_TRUE(*rx_record->a == 1);
}

// Array<TestStruct> sync: member `a` of a struct element added on the
// SENDER array reads back through at(0) on the RECEIVER array after a sync
// round.
CASE(TestStructSerial3) {
  auto tx_domain = new_domain(0);
  auto tx_records = tx_domain->New<Array<TestStruct>>(PlatoVariableSyncType::SENDER);
  *tx_records->add()->a = 1;

  auto rx_domain = new_domain(0);
  auto rx_records = rx_domain->New<Array<TestStruct>>(PlatoVariableSyncType::RECEIVER);

  // Sender stream -> receiver stream, then apply on the receiving side.
  tx_domain->get_stream() >> rx_domain->get_stream();
  rx_domain->do_sync();

  ASSERT_TRUE(*rx_records->at(0)->a == 1);
}

// Bool: a variable assigned `true` tests true in a boolean context.
CASE(TestBool1) {
  auto dom = new_domain(0);
  auto flag = dom->New<Bool>();
  *flag = true;
  ASSERT_TRUE(*flag);
}

// Bool sync: `true` written on the SENDER variable tests true on the
// RECEIVER variable after a sync round.
CASE(TestBoolSerial1) {
  auto tx_domain = new_domain(0);
  auto tx_flag = tx_domain->New<Bool>(PlatoVariableSyncType::SENDER);
  *tx_flag = true;

  auto rx_domain = new_domain(0);
  auto rx_flag = rx_domain->New<Bool>(PlatoVariableSyncType::RECEIVER);

  // Sender stream -> receiver stream, then apply on the receiving side.
  tx_domain->get_stream() >> rx_domain->get_stream();
  rx_domain->do_sync();

  ASSERT_TRUE(*rx_flag);
}

// Array<Int> iteration: walks the array with begin()/end() and erases the
// element equal to 2 during the walk, leaving two of the three elements.
CASE(TestArrayIterator1) {
  auto dom = new_domain(0);
  auto items = dom->New<Array<Int>>();
  *items->add() = 1;
  *items->add() = 2;
  *items->add() = 3;
  // Erase-while-iterating: continue from the iterator erase() hands back,
  // and advance manually only when nothing was erased.
  for (auto cursor = items->begin(); cursor != items->end();) {
    if (**cursor == 2) {
      cursor = items->erase(cursor);
    } else {
      cursor++;
    }
  }
  ASSERT_TRUE(items->size() == 2);
}

// Array<Int> iteration with sync: three elements are synced to the
// receiver, then the sender erases the element equal to 2 while iterating;
// a second sync round brings the receiver's array down to two elements.
CASE(TestArrayIterator2) {
  auto tx_domain = new_domain(0);
  auto tx_arr = tx_domain->New<Array<Int>>(PlatoVariableSyncType::SENDER);
  *tx_arr->add() = 1;
  *tx_arr->add() = 2;
  *tx_arr->add() = 3;

  auto rx_domain = new_domain(0);
  auto rx_arr = rx_domain->New<Array<Int>>(PlatoVariableSyncType::RECEIVER);

  // Round 1: the three insertions.
  tx_domain->get_stream() >> rx_domain->get_stream();
  rx_domain->do_sync();

  // Empty both streams before the second round.
  tx_domain->get_stream().clear();
  rx_domain->get_stream().clear();

  ASSERT_TRUE(rx_arr->size() == 3);

  // Erase-while-iterating on the sender: continue from the iterator
  // erase() hands back, and advance manually only when nothing was erased.
  for (auto cursor = tx_arr->begin(); cursor != tx_arr->end();) {
    if (**cursor == 2) {
      cursor = tx_arr->erase(cursor);
    } else {
      cursor++;
    }
  }
  ASSERT_TRUE(tx_arr->size() == 2);

  // Round 2: the erase.
  tx_domain->get_stream() >> rx_domain->get_stream();
  rx_domain->do_sync();

  ASSERT_TRUE(rx_arr->size() == 2);
}

// copy(): for each variable kind (Int, Float, String, Bool, Array, Set,
// Map) a destination variable copies from a source variable in the same
// domain and is then checked for the source's value or element count.
CASE(TestCopy1) {
  auto dom = new_domain(0);

  // Int
  auto int_src = dom->New<Int>(PlatoVariableSyncType::NONE, 1);
  auto int_dst = dom->New<Int>(PlatoVariableSyncType::NONE, 0);
  int_dst->copy(int_src.get());
  ASSERT_TRUE(*int_dst == 1);

  // Float
  auto float_src = dom->New<Float>(PlatoVariableSyncType::NONE, 1.0f);
  auto float_dst = dom->New<Float>(PlatoVariableSyncType::NONE, .0f);
  float_dst->copy(float_src.get());
  ASSERT_TRUE(*float_dst == 1.0f);

  // String
  auto str_src = dom->New<String>(PlatoVariableSyncType::NONE, "1");
  auto str_dst = dom->New<String>(PlatoVariableSyncType::NONE);
  str_dst->copy(str_src.get());
  ASSERT_TRUE(*str_dst == "1");

  // Bool
  auto bool_src = dom->New<Bool>(PlatoVariableSyncType::NONE, true);
  auto bool_dst = dom->New<Bool>(PlatoVariableSyncType::NONE);
  bool_dst->copy(bool_src.get());
  ASSERT_TRUE(*bool_dst == true);

  // Array<Int>: only the element count is checked.
  auto arr_src = dom->New<Array<Int>>(PlatoVariableSyncType::NONE);
  auto arr_dst = dom->New<Array<Int>>(PlatoVariableSyncType::NONE);
  *arr_src->add() = 1;
  *arr_src->add() = 2;
  arr_dst->copy(arr_src.get());
  ASSERT_TRUE(arr_dst->size() == 2);

  // Set<int>: only the element count is checked.
  auto set_src = dom->New<Set<int>>(PlatoVariableSyncType::NONE);
  auto set_dst = dom->New<Set<int>>(PlatoVariableSyncType::NONE);
  set_src->add(1);
  set_src->add(2);
  set_dst->copy(set_src.get());
  ASSERT_TRUE(set_dst->size() == 2);

  // Map<int, Int>: only the element count is checked.
  auto map_src = dom->New<Map<int, Int>>(PlatoVariableSyncType::NONE);
  auto map_dst = dom->New<Map<int, Int>>(PlatoVariableSyncType::NONE);
  *map_src->add(1) = 1;
  *map_src->add(2) = 2;
  map_dst->copy(map_src.get());
  ASSERT_TRUE(map_dst->size() == 2);
}

FIXTURE_END(TestPlatoVariable)
